#ifndef __NN_MODEL_H__
#define __NN_MODEL_H__

#include "header.hpp"
#include "layer.hpp"
#include "loss.hpp"
#include "net.hpp"
#include <assert.h>


namespace nn {

    // (loss value, gradients) pair returned by ModelBase::backward.
    using NNpair = std::pair<double, Tensorlist *>;

    /*
     * Base class for models built around a `net` (the layer container).
     *
     * Ownership: `core` and `loss` are stored as raw pointers and are never
     * deleted by this class; the caller must keep both alive for as long as
     * the model is used. NOTE(review): confirm this is the intended ownership.
     */
    class ModelBase {
    private:
        string _name;
        LossBase *_loss;
        net *_core;

        // Overridable gradient hook (not called anywhere in this class).
        // The default returns a heap-allocated Tensorlist(0) that the caller
        // must delete.
        virtual Tensorlist *_backward(const Dmatrix & /*grad*/) {
            return new Tensorlist(0);
        }

        // Returns the loss function, asserting (in debug builds) that one
        // has been set.
        LossBase *get_loss() {
            assert(_loss != NULL && "Loss have not defined.");
            return _loss;
        }

    public:
        /*
         * args:
         * - core: layer container the model runs (not owned)
         * - loss: loss function used by backward()/train() (not owned, may be
         *         NULL until a loss is needed)
         * - name: model name; passed to assignable_constraint (defined
         *         elsewhere, presumably a validity check) before being stored
         */
        ModelBase(net *core, LossBase *loss = NULL, const string &name = "")
            : _loss(loss), _core(core) { // listed in declaration order
            assignable_constraint(name);
            _name = name;
        }

        // Virtual because the class is used polymorphically (virtual forward,
        // subclass Sequential): deleting a subclass through a ModelBase* is
        // otherwise undefined behaviour. Does not free _core or _loss.
        virtual ~ModelBase() = default;

        void set_model_name(const string &name) noexcept { _name = name; }

        string get_model_name() const noexcept { return _name; }

        // Appends a layer to the underlying net.
        void addLayer(LayerBase *l) { _core->add(l); }

        const net &get_net() const { return *_core; }

        // Runs the underlying net's forward pass on x.
        virtual Dmatrix forward(const Dmatrix &x) { return _core->forward(x); }

        /*
         * usage:
         *  Compute the loss and back-propagate its gradient through the net.
         * args:
         * - prediction: predicted values (e.g. the output of forward())
         * - target:     target values
         * returns:
         *  NNpair (loss, grads)
         *      - loss : loss value reported by the loss function
         *      - grads: gradients of all layers, as returned by net::backward.
         *               NOTE(review): ownership of this pointer is not
         *               established here; nothing in this class frees it.
         * Asserts (debug builds) that a loss function has been set.
         */
        NNpair backward(const Dmatrix &prediction, const Dmatrix &target) {
            LossBase *loss_fn = get_loss();
            double loss = loss_fn->loss(prediction, target);
            Dmatrix grad = loss_fn->grad(prediction, target);
            Tensorlist *grads = _core->backward(grad);
            return NNpair(loss, grads);
        }

        /*
         * Update the weights of every `liner` layer using `grads`.
         *
         * Gradients are consumed in lock-step with the net's layer list (one
         * entry per layer, skipped layers included); iteration stops at the
         * shorter of the two. NOTE(review): confirm net::backward emits one
         * entry per layer, in layer order.
         */
        void apply_grad(const Tensorlist *grads) {
            assert(grads != NULL && "Gradients are null.");
            const auto &layer_list = _core->get_layerlist();
            auto il = layer_list.begin();
            Tensorlist::const_iterator it = grads->begin();

            for (; il != layer_list.end() && it != grads->end(); ++il, ++it) {
                if ((*il)->layerTp() != liner)
                    continue; // only layers of type `liner` are updated
                Dmatrix w = (*il)->get_w();
                // NOTE(review): the gradient is *added* row-wise to w with no
                // learning rate; plain gradient descent would subtract a scaled
                // gradient. Confirm the sign/scaling convention of the grads.
                (*il)->set_w(w.rowwise() + *(*it));
            }
        }

        /*
         * One training step: forward pass on x, loss/gradient against y,
         * then weight update via apply_grad().
         */
        void train(const Dmatrix &x, const Dmatrix &y) {
            Dmatrix prediction = _core->forward(x);
            NNpair loss_grads = backward(prediction, y);
            apply_grad(loss_grads.second);
            // NOTE(review): loss_grads.second is never freed here; if
            // net::backward returns a fresh heap allocation (as _backward's
            // default does) this leaks once per training step.
        }
    }; // class Model

    /*
     * Concrete model whose forward pass simply runs the wrapped `net`.
     */
    class Sequential : public ModelBase {
    public:
        // Wrap an existing net (not owned); the model is named "Sequential".
        Sequential(net *core, LossBase *loss = NULL)
            : ModelBase(core, loss, "Sequential") {}

        // Build the model around a newly allocated net(0, NULL).
        // NOTE(review): that net is heap-allocated here and never deleted
        // (ModelBase does not free _core), so it leaks when the model is
        // destroyed. Fixing this needs a decision on whether net owns/deletes
        // its layers.
        Sequential(LossBase *loss = NULL, const string &name = "Sequential")
            : ModelBase(new net(0, NULL), loss, name) {}

        // Runs the forward pass on the model's own net rather than on a
        // temporary copy: copying the net is wasted work and would discard
        // any state net::forward records (and could double-free if net owns
        // its layers by raw pointer).
        Dmatrix forward(const Dmatrix &x) override {
            return ModelBase::forward(x);
        }

        ~Sequential() = default;

    }; // class Sequential
} // namespace nn
#endif